In [ ]:
# REQUIREMENTS
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import os, json

# Select a sans-serif font for matplotlib; presumably needed so that the
# Chinese entity names used as node labels render instead of missing-glyph
# boxes. NOTE(review): 'Arial Unicode MS' is not installed on every platform —
# confirm it is available where this notebook runs.
# ref: https://blog.csdn.net/fwj_ntu/article/details/105598145
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] 
In [ ]:
# Read every *.json file found in <cwd>/datasets/diakg and collect the parsed
# documents in `diakg` (one list item per file).
dataset_path = os.path.join(os.getcwd(), 'datasets')
dataset_name = 'diakg'
dataset_path = os.path.join(dataset_path, dataset_name)
diakg = []

# os.listdir() returns entries in arbitrary order; sorting the names makes the
# document order -- and everything derived from it in later cells -- the same
# on every run and every machine.
for file in sorted(os.listdir(dataset_path)):
    if not file.endswith('.json'):
        # skip anything that is not a JSON file (e.g. macOS .DS_Store)
        continue
    file_path = os.path.join(dataset_path, file)
    with open(file_path, 'r', encoding='utf-8') as f:
        diakg.append(json.load(f))
In [ ]:
# TOOLS

def remove_none_label(l: dict) -> dict:
    """
    Remove every entry whose value is None from a label dictionary.

    :l: mapping of key -> label; it is modified in place
    :return: the same dictionary object, without the None-valued entries
    """
    # Collect the keys first: a dict must not change size while it is being
    # iterated. `is None` is an identity test, so falsy labels (0, '') and
    # objects with a permissive __eq__ are never dropped by mistake.
    none_keys = [key for key, value in l.items() if value is None]

    for key in none_keys:
        del l[key]

    return l

def get_node_centric_graph(df, node_name):
    """
    Extract the triples that start at a given center node.

    :df: data frame with 'source', 'target' and 'edge' columns
    :node_name: the center node name
    :return: new data frame with the same three columns and a fresh 0..n-1
             index, holding (in the original row order) every row whose
             source is `node_name` itself or a tuple whose first element is
             `node_name`
    """

    sources, targets, relations = [], [], []

    # Walk the column values directly instead of using df.loc[i]: this works
    # for any index (not only 0..n-1) and avoids building a row Series for
    # every single lookup.
    for source, target, relation in zip(df['source'], df['target'], df['edge']):
        if source == node_name:
            kept_source = node_name
        elif isinstance(source, tuple) and source and source[0] == node_name:
            # tuple-valued source whose first element is the center node;
            # an empty tuple cannot match and is skipped
            kept_source = source
        else:
            continue

        sources.append(kept_source)
        targets.append(target)
        relations.append(relation)

    new_df = pd.DataFrame({'source': sources, 'target': targets, 'edge': relations})

    return new_df

def draw_graph(df):
    """
    Draw a directed graph given a triple data frame.

    :df: data frame with 'source', 'target' and 'edge' columns; each row
         becomes one directed edge source -> target carrying 'edge' as its
         attribute

    Tuple-valued nodes are kept as edge endpoints but are drawn without a
    marker or a text label, and edges whose 'edge' attribute is None get no
    edge label. The figure is shown with plt.show(); nothing is returned.
    """

    # Build the directed graph from the data frame that was passed in (this
    # used to read a notebook-global variable and ignore the argument).
    G = nx.from_pandas_edgelist(df, source='source', target='target', edge_attr='edge', create_using=nx.DiGraph())

    plt.figure(figsize=(100, 100))
    pos = nx.spring_layout(G)
    # keep only the edges that actually carry a label
    edge_labels = remove_none_label(nx.get_edge_attributes(G, 'edge'))
    # tuple nodes get neither a text label ...
    node_labels = {n: n for n in G if not isinstance(n, tuple)}
    # ... nor a node marker
    node_list = [n for n in G if not isinstance(n, tuple)]
    # draw nodes
    nx.draw_networkx_nodes(G, pos, nodelist=node_list)
    # draw all edges, including those that touch tuple nodes
    nx.draw_networkx_edges(G, pos)
    # draw labels for all desired nodes
    nx.draw_networkx_labels(G, pos, labels=node_labels)
    # draw labels for all desired node-node edges
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    plt.show()
In [ ]:
# Build the triple table. Every relation is expanded into two rows that pass
# through an intermediate (source, relation_type) tuple node:
#   source -> (source, relation_type)   labelled with relation_type
#   (source, relation_type) -> target   labelled with None
sources, targets, relations = [], [], []
entity_map, relation_map = dict(), dict()
used_entity = set()

for doc in diakg:
    for paragraph in doc['paragraphs']:
        for sentence in paragraph['sentences']:
            # Register entity names; the first name seen for an id wins.
            # NOTE(review): the map is shared by all documents -- if entity
            # ids are only unique per document, later documents resolve to
            # the wrong names. Confirm against the dataset.
            for entity in sentence['entities']:
                entity_map.setdefault(entity['entity_id'], entity['entity'])

            for relation in sentence['relations']:
                # NOTE(review): `source` is taken from the tail entity and
                # `target` from the head entity -- confirm this direction is
                # intended.
                source = entity_map[relation['tail_entity_id']]
                target = entity_map[relation['head_entity_id']]
                # remember which entities take part in at least one relation
                used_entity.update((source, target))

                relation_type = relation['relation_type']
                mid_target = (source, relation_type)
                # two rows per relation, in this order
                sources.extend([source, mid_target])
                targets.extend([mid_target, target])
                relations.extend([relation_type, None])

                relation_map.setdefault(relation['relation_id'], relation['relation_type'])

df = pd.DataFrame({'source': sources, 'target': targets, 'edge': relations})
df
Out[ ]:
source target edge
0 2型糖尿病 (2型糖尿病, Class_Disease) Class_Disease
1 (2型糖尿病, Class_Disease) 2型 None
2 2型糖尿病 (2型糖尿病, Treatment_Disease) Treatment_Disease
3 (2型糖尿病, Treatment_Disease) 自我管理 None
4 2型糖尿病 (2型糖尿病, Class_Disease) Class_Disease
... ... ... ...
17281 (糖尿病, Drug_Disease) DSME/S None
17282 血糖监测 (血糖监测, Class_Disease) Class_Disease
17283 (血糖监测, Class_Disease) 降低 None
17284 血糖控制 (血糖控制, Class_Disease) Class_Disease
17285 (血糖控制, Class_Disease) 糖尿病 None

17286 rows × 3 columns

In [ ]:
# Keep only the triples that start at the node '糖尿病' (diabetes) ...
t = get_node_centric_graph(df, '糖尿病')
# NOTE(review): under IPython's default settings only the last expression of a
# cell is displayed, so this bare `t` may have no visible effect here — confirm.
t
# ... and plot that sub-graph.
draw_graph(t)
Out[ ]:
source target edge
0 糖尿病 (糖尿病, Treatment_Disease) Treatment_Disease
1 (糖尿病, Treatment_Disease) 自身管理 None
2 糖尿病 (糖尿病, Treatment_Disease) Treatment_Disease
3 (糖尿病, Treatment_Disease) 糖尿病自我管理教育和支持(diabetes self-management educatio... None
4 糖尿病 (糖尿病, Treatment_Disease) Treatment_Disease
... ... ... ...
1845 (糖尿病, Drug_Disease) 糖尿病教育 None
1846 糖尿病 (糖尿病, Drug_Disease) Drug_Disease
1847 (糖尿病, Drug_Disease) 心理健康咨询 None
1848 糖尿病 (糖尿病, Drug_Disease) Drug_Disease
1849 (糖尿病, Drug_Disease) DSME/S None

1850 rows × 3 columns